import math
import random
import numpy as np
from sklearn.cluster import KMeans
from skimage import io, color
import kociemba
import packages.move as move
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
class Cube:
    """Rubik's cube state stored as a 54-sticker facelet string.

    ``cube_string`` lists the stickers face by face in U, R, F, D, L, B
    order, nine per face, each face in row-major order (top-left to
    bottom-right) -- the layout the turn tables below are written for and
    the order ``kociemba.solve`` expects.
    """

    # Class-level default; make_cube_string / rotate_side replace it with a
    # per-instance str. Kept as an empty list for backward compatibility.
    cube_string = []
    # Regions of a captured camera image, one per visible sticker ("piece")
    # of the photographed face. NOTE(review): each tuple is presumably an
    # (x0, y0, x1, y1) bounding box -- confirm against the capture code.
    pixel_groups = [
        (163, 0, 185, 34),
        (341, 0, 430, 34),
        (568, 0, 627, 34),
        (71, 140, 160, 300),
        (341, 140, 430, 300),
        (624, 140, 700, 300),
        (154, 413, 190, 472),
        (333, 422, 424, 472),
        (573, 417, 622, 476),
    ]

    # Face letter -> position of that face's block of nine stickers.
    _FACE_INDEX = {'U': 0, 'R': 1, 'F': 2, 'D': 3, 'L': 4, 'B': 5}
    # For each face: its four neighbours in clockwise order as seen from
    # outside that face, each paired with the edge of the neighbour (a key of
    # _EDGE_STICKERS) that touches it.
    _ADJACENT_EDGES = {
        'U': (('B', 0), ('R', 0), ('F', 0), ('L', 0)),
        'R': (('U', 1), ('B', 3), ('D', 1), ('F', 1)),
        'F': (('U', 2), ('R', 3), ('D', 0), ('L', 1)),
        'D': (('F', 2), ('R', 2), ('B', 2), ('L', 2)),
        'L': (('U', 3), ('F', 3), ('D', 3), ('B', 1)),
        'B': (('U', 0), ('L', 3), ('D', 2), ('R', 1)),
    }
    # Sticker indices along each edge of a face, listed anticlockwise around
    # that face -- which is clockwise around the neighbour being turned.
    _EDGE_STICKERS = {
        0: (2, 1, 0),  # top row
        1: (8, 5, 2),  # right column
        2: (6, 7, 8),  # bottom row
        3: (0, 3, 6),  # left column
    }

    def rotate_slist(self, side_list, rotate_num):
        """Return the 3x3 face ``side_list`` turned clockwise ``rotate_num`` times.

        ``side_list`` holds nine stickers in row-major order. ``rotate_num``
        may be any integer; it is taken modulo 4, so negative values turn
        anticlockwise. A non-zero turn returns a new list; a multiple of 4
        returns ``side_list`` itself, unchanged.
        """
        # rotated[k] = original[order[k]] for one clockwise quarter turn.
        order = (6, 3, 0, 7, 4, 1, 8, 5, 2)
        turns = rotate_num % 4
        if turns == 0:
            return side_list
        rotated = side_list
        for _ in range(turns):
            rotated = [rotated[k] for k in order]
        return rotated

    def rotate_side(self, side, rotate_num):
        """Turn face ``side`` clockwise by ``rotate_num`` quarter turns in place.

        ``side`` is one of 'U', 'R', 'F', 'D', 'L', 'B'. ``rotate_num`` may be
        any integer; it is taken modulo 4, so -1 is one anticlockwise turn.
        After a non-zero turn ``self.cube_string`` is rebuilt as a ``str``;
        a multiple of 4 leaves it untouched.

        Raises:
            KeyError: if a non-zero turn is requested for an unknown face.
        """
        turns = rotate_num % 4
        if turns == 0:
            return
        face_start = 9 * self._FACE_INDEX[side]
        # Absolute indices of the 12 stickers ringing ``side``, walked
        # clockwise around it, three per neighbouring face.
        ring = [
            9 * self._FACE_INDEX[neighbour] + index
            for neighbour, edge in self._ADJACENT_EDGES[side]
            for index in self._EDGE_STICKERS[edge]
        ]
        stickers = list(self.cube_string)
        for _ in range(turns):
            stickers[face_start:face_start + 9] = self.rotate_slist(
                stickers[face_start:face_start + 9], 1)
            # A clockwise quarter turn carries each neighbour's strip of three
            # onto the next neighbour clockwise.
            moved = [stickers[i] for i in ring]
            moved = moved[-3:] + moved[:-3]
            for i, sticker in zip(ring, moved):
                stickers[i] = sticker
        self.cube_string = "".join(stickers)

    @staticmethod
    def _plot_colour_spaces(ordered_colours, lab_pixels):
        """Draw 3-D scatter plots of the stickers in weighted Lab and in RGB space.

        The figures are only created here; they appear on the next
        ``plt.show()``.
        """
        # Scale presumed 0-255 channel values into matplotlib's 0-1 range.
        swatches = [(c[0] / 256, c[1] / 256, c[2] / 256) for c in ordered_colours]
        lab_axes = plt.figure().add_subplot(111, projection='3d')
        rgb_axes = plt.figure().add_subplot(111, projection='3d')
        # The extra black point at (7000, 7000, 7000) widens the autoscaled
        # Lab axes. NOTE(review): presumably to fix the plot scale -- confirm.
        lab_axes.scatter(
            [p[0] for p in lab_pixels] + [7000],
            [p[1] for p in lab_pixels] + [7000],
            [p[2] for p in lab_pixels] + [7000],
            c=swatches + [(0, 0, 0)],
        )
        rgb_axes.scatter(
            [c[0] for c in ordered_colours],
            [c[1] for c in ordered_colours],
            [c[2] for c in ordered_colours],
            c=swatches,
        )

    @staticmethod
    def _group_lab_colours(points, n_groups=6, group_size=9):
        """Greedily partition the indices of ``points`` into equal-sized groups.

        Each group is seeded with the closest pair (squared Euclidean
        distance) of still-unassigned points and grown one point at a time by
        adding the unassigned point nearest to any current member. Ties go to
        the earlier pair, the earlier member and the lower point index.

        Returns:
            A list of ``n_groups`` lists of ``group_size`` point indices each.

        Raises:
            ValueError: if ``group_size`` < 2 or there are fewer than
                ``n_groups * group_size`` points.
        """
        count = len(points)
        if group_size < 2 or count < n_groups * group_size:
            raise ValueError(
                f"need group_size >= 2 and at least {n_groups * group_size} "
                f"points, got group_size={group_size} and {count} points")
        # Squared distances: the square root is unnecessary because only the
        # ordering of distances matters.
        dist = [
            [sum(math.pow(p - q, 2) for p, q in zip(a, b)) for b in points]
            for a in points
        ]
        # Every unordered pair (i < j, so a point is never paired with
        # itself), closest first; the stable sort keeps (i, j) order on ties.
        pairs = sorted(
            ((dist[i][j], i, j) for i in range(count) for j in range(i + 1, count)),
            key=lambda pair: pair[0],
        )
        unassigned = set(range(count))
        groups = []
        for _ in range(n_groups):
            # Both members of the seed must still be free, otherwise a point
            # could end up in two groups.
            group = list(next((i, j) for _, i, j in pairs
                              if i in unassigned and j in unassigned))
            unassigned.difference_update(group)
            while len(group) < group_size:
                # Each member's nearest free point (ties -> lower index), then
                # the closest of those candidates (ties -> earlier member).
                candidates = [
                    min((dist[member][other], other) for other in unassigned)
                    for member in group
                ]
                _, chosen = min(candidates, key=lambda cand: cand[0])
                group.append(chosen)
                unassigned.discard(chosen)
            groups.append(group)
        return groups

    @staticmethod
    def _plot_labelled_stickers(colours, labels, rows=9):
        """Draw a grid of sticker swatches marked with their face letters, then call plt.show()."""
        cols = len(colours) // rows
        # squeeze=False guarantees a 2-D array of axes even for one row/column.
        _, axes = plt.subplots(rows, cols, figsize=(8, 4), squeeze=False)
        # axes.flat walks row-major, so box (i, j) shows sticker i * cols + j.
        for index, axis in enumerate(axes.flat):
            axis.add_patch(plt.Rectangle(
                (0, 0), 1, 1, color=[a / 256 for a in colours[index]], alpha=0.7))
            axis.text(0.5, 0.5, labels[index], fontsize=20,
                      ha='center', va='center', color='white')
            axis.axis('off')
        plt.tight_layout()
        plt.show()

    def make_cube_string(self, colours):
        """Classify the sampled sticker colours and store the result in ``self.cube_string``.

        ``colours`` holds six nine-sticker faces in capture order D, L, B, U,
        R, F. Each face is rotated by a fixed number of quarter turns
        (presumably undoing the camera's viewing orientation) and the faces
        are reordered to U, R, F, D, L, B. The colours are converted to CIE
        Lab, with lightness scaled by 5 so brightness weighs more in the
        distance metric. They are then greedily grouped into six clusters of
        nine, and each cluster is named after the face whose centre sticker it
        contains. Debug values are printed and diagnostic plots are shown.

        NOTE(review): stickers are assumed to be RGB triples on a 0-255 scale
        (the plots divide by 256) -- confirm with the caller. If they are
        Python ints, ``skimage.color.rgb2lab`` rescales by the integer dtype's
        range rather than by 255.

        Raises:
            KeyError: if two centre stickers fall in the same cluster, leaving
                some cluster without a face name.
        """
        # (index into ``colours``, clockwise quarter turns) for U, R, F, D, L, B.
        layout = ((3, 0), (4, 1), (5, 1), (0, 1), (1, 2), (2, 2))
        ordered_colours = []
        for source, turns in layout:
            ordered_colours += self.rotate_slist(colours[source], turns)

        lab_pixels = []
        for rgb in ordered_colours:
            lightness, a_star, b_star = color.rgb2lab(np.array([[rgb]]))[0][0]
            lab_pixels.append([lightness * 5, a_star, b_star])

        self._plot_colour_spaces(ordered_colours, lab_pixels)
        print(lab_pixels)

        groups = self._group_lab_colours(lab_pixels)
        cube_nums = ['n'] * len(lab_pixels)
        for group_number, group in enumerate(groups):
            for sticker in group:
                cube_nums[sticker] = group_number
        print(cube_nums)

        # Name each cluster after the face whose centre sticker (index 4 of
        # that face's nine) it contains.
        labels_lookup = {
            cube_nums[9 * face + 4]: name for face, name in enumerate('URFDLB')
        }
        print(labels_lookup)
        cube_string = [labels_lookup[label] for label in cube_nums]
        self.cube_string = "".join(cube_string)
        print(self.cube_string)

        self._plot_labelled_stickers(ordered_colours, cube_string)